CUDA_VISIBLE_DEVICES=0 python main.py \
    --shared layer2 --rotation_type expand --group_norm 8 \
    --nepoch 150 --milestone_1 75 --milestone_2 125 \
    --outf results/cifar10_layer2_gn_expand

In [1]:
from __future__ import print_function

import os
import sys
import argparse

import torch
from torch import nn, optim

from ttt.vision.utils.misc import *
from ttt.vision.utils.test_helpers import *
from ttt.vision.utils.prepare_dataset import *
from ttt.vision.utils.rotation import rotate_batch

from tqdm.notebook import tqdm

In [2]:
# List the visible GPUs with their current memory use / utilisation, to help
# choose which device CUDA_VISIBLE_DEVICES should expose to this kernel.
!nvidia-smi

Mon Nov 25 03:46:51 2024       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 495.29.05    Driver Version: 495.29.05    CUDA Version: 11.5     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  On   | 00000000:04:00.0 Off |                    0 |
| N/A   38C    P0    34W / 250W |   6738MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  Tesla P100-PCIE...  On   | 00000000:06:00.0 Off |                    0 |
| N/A   37C    P0    25W / 250W |      2MiB / 16280MiB |      0%      Defaul

In [3]:
# Pin this kernel to a single physical GPU. CUDA reads CUDA_VISIBLE_DEVICES
# only once, when the CUDA runtime is first initialised, so changing it after
# any `.cuda()` call would be silently ignored -- fail loudly instead.
GPU_ID = '2'
if torch.cuda.is_initialized() and os.environ.get('CUDA_VISIBLE_DEVICES') != GPU_ID:
    raise RuntimeError(
        'CUDA is already initialised; restart the kernel before changing '
        'CUDA_VISIBLE_DEVICES to %r' % GPU_ID)
os.environ['CUDA_VISIBLE_DEVICES'] = GPU_ID

In [6]:
# Emulate a command-line invocation so the argparse cell below can be used
# unchanged inside the notebook (argparse skips argv[0], the program name).
sys.argv = ["python"] + (
    "--shared layer2 --rotation_type expand "
    "--group_norm 8 --nepoch 150 "
    "--milestone_1 75 --milestone_2 125 "
    "--outf ./results/cifar10_layer2_gn_expand"
).split()

In [17]:
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', default='cifar10')
parser.add_argument('--dataroot', default='./datasets/')
parser.add_argument('--shared', default=None)
########################################################################
parser.add_argument('--depth', default=26, type=int)
parser.add_argument('--width', default=1, type=int)
parser.add_argument('--batch_size', default=128, type=int)
parser.add_argument('--group_norm', default=0, type=int)
########################################################################
parser.add_argument('--lr', default=0.1, type=float)
parser.add_argument('--nepoch', default=75, type=int)
parser.add_argument('--milestone_1', default=50, type=int)
parser.add_argument('--milestone_2', default=65, type=int)
parser.add_argument('--rotation_type', default='rand')
########################################################################
parser.add_argument('--outf', default='.')

args = parser.parse_args()

In [18]:
import torch.backends.cudnn as cudnn

# Create the output directory, then let cuDNN autotune convolution algorithms
# (faster for fixed input sizes, at the cost of bitwise run-to-run determinism).
my_makedir(args.outf)
cudnn.benchmark = True

# Seed torch's global RNG before anything random happens, so that (presumably)
# weight initialisation in build_model and the training-set shuffle order are
# repeatable under Restart & Run All. cudnn.benchmark can still add small
# numerical nondeterminism.
SEED = 0
torch.manual_seed(SEED)

# build_model returns the full classifier, the shared feature extractor, the
# self-supervised head and their combination (ssh); only the loaders are kept.
net, ext, head, ssh = build_model(args)
_, teloader = prepare_test_data(args)
_, trloader = prepare_train_data(args)

Building model...
Test on the original test set
Files already downloaded and verified
Preparing data...
Files already downloaded and verified


In [19]:
# One optimiser over the main network plus the self-supervised head.
parameters = [*net.parameters(), *head.parameters()]
optimizer = optim.SGD(parameters, lr=args.lr, momentum=0.9, weight_decay=5e-4)

# Step decay: multiply the learning rate by 0.1 at each milestone epoch.
milestones = [args.milestone_1, args.milestone_2]
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones, gamma=0.1, last_epoch=-1)

# Shared by the classification and rotation-prediction losses.
criterion = nn.CrossEntropyLoss().cuda()

In [20]:
all_err_cls, all_err_ssh = [], []
print('Running...')
print('Error (%)\t\ttest\t\tself-supervised')

# The auxiliary rotation task is trained/evaluated only when a layer is shared.
use_ssh = args.shared is not None

for epoch in tqdm(range(1, args.nepoch + 1)):
    net.train()
    ssh.train()

    for inputs, labels in trloader:
        optimizer.zero_grad()
        # Main supervised classification loss.
        loss = criterion(net(inputs.cuda()), labels.cuda())

        if use_ssh:
            # Auxiliary loss on rotated copies; rotate_batch also returns the labels.
            inputs_ssh, labels_ssh = rotate_batch(inputs, args.rotation_type)
            loss += criterion(ssh(inputs_ssh.cuda()), labels_ssh.cuda())

        loss.backward()
        optimizer.step()

    # Per-epoch test error of the classifier and, if enabled, of the
    # self-supervised head (always evaluated with 'expand' labels).
    err_cls = test(teloader, net)[0]
    if use_ssh:
        err_ssh = test(teloader, ssh, sslabel='expand')[0]
    else:
        err_ssh = 0
    all_err_cls.append(err_cls)
    all_err_ssh.append(err_ssh)
    scheduler.step()

    epoch_label = 'Epoch %d/%d:' % (epoch, args.nepoch)
    print(epoch_label.ljust(24) + '%.2f\t\t%.2f' % (err_cls * 100, err_ssh * 100))

    # Rewrite the error history and its plot every epoch, so an interrupted
    # run still leaves them on disk.
    torch.save((all_err_cls, all_err_ssh), args.outf + '/loss.pth')
    plot_epochs(all_err_cls, all_err_ssh, args.outf + '/loss.pdf')

Running...
Error (%)		test		self-supervised


  0%|          | 0/150 [00:00<?, ?it/s]

Epoch 1/150:            62.54		44.49
Epoch 2/150:            50.58		39.46
Epoch 3/150:            44.09		35.37
Epoch 4/150:            41.66		32.16
Epoch 5/150:            35.78		28.65
Epoch 6/150:            30.77		27.26
Epoch 7/150:            34.10		27.78
Epoch 8/150:            27.67		22.95
Epoch 9/150:            29.02		23.79
Epoch 10/150:           23.32		22.51
Epoch 11/150:           28.98		23.18
Epoch 12/150:           22.84		21.48
Epoch 13/150:           23.17		22.50
Epoch 14/150:           22.77		20.42
Epoch 15/150:           21.33		20.08
Epoch 16/150:           20.76		19.46
Epoch 17/150:           26.97		19.26
Epoch 18/150:           20.84		18.67
Epoch 19/150:           19.76		18.33
Epoch 20/150:           23.22		20.71
Epoch 21/150:           21.30		18.99
Epoch 22/150:           21.40		17.89
Epoch 23/150:           19.48		18.81
Epoch 24/150:           18.92		18.89
Epoch 25/150:           21.46		19.36
Epoch 26/150:           20.41		18.89
Epoch 27/150:           19.70		19.74
E

In [21]:
# Final checkpoint: the last evaluated test errors plus the state needed to
# resume training. 'scheduler' and 'epoch' are stored so a resumed run
# continues the MultiStepLR schedule instead of restarting at the initial LR.
# `epoch` is the training loop's variable (equal to args.nepoch when the loop
# ran to completion).
state = {'err_cls': err_cls, 'err_ssh': err_ssh,
         'net': net.state_dict(), 'head': head.state_dict(),
         'optimizer': optimizer.state_dict(),
         'scheduler': scheduler.state_dict(),
         'epoch': epoch}
torch.save(state, args.outf + '/ckpt.pth')